import java.util.*;

public class _508 {

    /**
     * Recursive solution to "Most Frequent Subtree Sum" (LeetCode 508).
     *
     * <p>A post-order traversal computes the sum of every subtree and counts how often each sum
     * occurs; the sums that occur most often are returned.
     */
    static class Solution1 {
        /**
         * Returns every subtree sum that occurs the maximum number of times.
         *
         * @param root root of the tree; may be {@code null}
         * @return the most frequent subtree sums (in no guaranteed order), or an empty array when
         *     {@code root} is {@code null}
         */
        public int[] findFrequentTreeSum(TreeNode root) {
            // An empty tree has no subtree sums at all.
            if (root == null) return new int[0];
            Map<Integer, Integer> sumCount = new HashMap<>();
            postOrder(root, sumCount);
            return keysWithMaxCount(sumCount);
        }

        /**
         * Post-order traversal that records the sum of every subtree rooted at or below
         * {@code root}.
         *
         * @param root     subtree root; {@code null} contributes 0 and is not counted
         * @param sumCount accumulator mapping each subtree sum to its number of occurrences;
         *                 updated in place
         * @return the sum of all node values in the subtree rooted at {@code root}
         */
        public int postOrder(TreeNode root, Map<Integer, Integer> sumCount) {
            if (root == null) return 0;
            // Children first (post-order), so both child sums are known before this node's.
            int cur = root.val + postOrder(root.left, sumCount) + postOrder(root.right, sumCount);
            sumCount.merge(cur, 1, Integer::sum);
            return cur;
        }
    }

    /**
     * Iterative solution using an explicit-stack post-order traversal.
     *
     * <p>A node is popped only after all of its existing children have been popped, so at pop
     * time its subtree sum is complete and is added to its parent's running total.
     */
    static class Solution2 {
        /**
         * Returns every subtree sum that occurs the maximum number of times.
         *
         * @param root root of the tree; may be {@code null}
         * @return the most frequent subtree sums (in no guaranteed order), or an empty array when
         *     {@code root} is {@code null}
         */
        public int[] findFrequentTreeSum(TreeNode root) {
            if (root == null) return new int[0];
            Deque<TreeNode> stack = new ArrayDeque<>();
            Map<Integer, Integer> sumCount = new HashMap<>();
            // Sum of the already-finished children's subtree sums, keyed by their parent node.
            // Identity-keyed so distinct nodes never collide even if TreeNode overrides equals.
            Map<TreeNode, Integer> childSums = new IdentityHashMap<>();
            // Most recently popped node; null until the first pop (no sentinel node needed).
            TreeNode lastPop = null;
            stack.push(root);
            while (!stack.isEmpty()) {
                TreeNode top = stack.peek();
                // All children are finished when the last popped node is the last child that
                // exists (right if present, otherwise left), or when the node is a leaf.
                boolean childrenDone = top.right != null
                        ? top.right == lastPop
                        : top.left == null || top.left == lastPop;
                if (childrenDone) {
                    lastPop = stack.pop();
                    Integer fromChildren = childSums.remove(lastPop);
                    int sum = lastPop.val + (fromChildren == null ? 0 : fromChildren);
                    sumCount.merge(sum, 1, Integer::sum);
                    // Children are always pushed directly above their parent, so the new top of
                    // the stack (if any) is the parent of the node just popped.
                    TreeNode parent = stack.peek();
                    if (parent != null) {
                        childSums.merge(parent, sum, Integer::sum);
                    }
                } else if (top.left != null && top.left != lastPop) {
                    // First visit to this node: descend into the left subtree.
                    stack.push(top.left);
                } else {
                    // Left subtree is absent or finished and the right subtree is not yet visited.
                    // top.right is non-null here; otherwise childrenDone would have been true.
                    stack.push(top.right);
                }
            }
            return keysWithMaxCount(sumCount);
        }
    }

    /**
     * Returns every key of {@code counts} whose count equals the largest count in the map.
     *
     * @param counts occurrences per subtree sum; may be empty
     * @return the matching keys in the map's iteration order; empty if {@code counts} is empty
     */
    private static int[] keysWithMaxCount(Map<Integer, Integer> counts) {
        int maxCount = 0;
        for (int count : counts.values()) {
            maxCount = Math.max(maxCount, count);
        }
        List<Integer> keys = new ArrayList<>();
        for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
            if (e.getValue().intValue() == maxCount) {
                keys.add(e.getKey());
            }
        }
        int[] result = new int[keys.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = keys.get(i);
        }
        return result;
    }
}
